from data import Data
from environment import Environment
import numpy as np
import pickle

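# For each subject, computes the observed rate of high confidence on the trials of
# sessions 3 and 4 (as indexed in the output of Data.split_session) together with the
# rate predicted by each of the four confidence models, and prints both as LaTeX table rows.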


def latex_string(conf_array, name, decimals=2):
    r"""Print and return one LaTeX table row of percentages.

    The row has the form ``name & v1 & v2 & ... \\ `` where each ``v`` is the
    matching entry of *conf_array* multiplied by 100 and rounded to
    *decimals* places.

    Args:
        conf_array: iterable of rates given as fractions (e.g. 0.25 -> 25.0).
        name: label written in the first column of the row.
        decimals: decimal places kept after scaling to percent (default 2).

    Returns:
        The row string that was printed.
    """
    cells = [name] + [f"{round(conf * 100, decimals)}" for conf in conf_array]
    # "\\\\" is two literal backslashes: the LaTeX row terminator "\\".
    latex_str = " & ".join(cells) + " \\\\ "
    print(latex_str)
    return latex_str


if __name__ == '__main__':

    # loads and cleans data
    data = Data('data/rawChoiceData.txt')
    trials_sub = data.split_subject(data.original_trials)
    num_sub = len(trials_sub.keys())
    num_session = 7
    num_models = 4
    beta = 1

    value_confs = np.zeros(num_sub)
    value_confs_pred = np.zeros((num_sub, num_models))
    params = np.load("results/conf/params/conf_params.pkl", allow_pickle=True)
    sigzs = np.mean(params.sigzs, axis=0)
    sigzs_subs = np.mean(params.sigzs_subs, axis=0)
    cc_pai = np.mean(params.cc_pai, axis=0)
    cc_obsv = np.mean(params.cc_obsv, axis=0)
    cc_bayes = np.mean(params.cc_bayes, axis=0)
    cc_ev = np.mean(params.cc_ev, axis=0)
    cutoffs = [cc_pai, cc_obsv, cc_bayes, cc_ev]

    for sub in trials_sub:
        split_trials = Data.split_session(trials_sub[sub])
        value_trials = np.concatenate((split_trials[3], split_trials[4]))
        value_confs[int(sub) - 1] = Environment.get_real_conf(value_trials)
        sigz = sigzs[int(sub) - 1]
        sigz_sub = sigzs_subs[int(sub) - 1]

        for conf_type in range(4):
            cutoff = cutoffs[conf_type][(int(sub) - 1)]
            if conf_type == 0:
                value_confs_pred[int(sub) - 1, conf_type] = Environment.get_pred_conf(value_trials, cutoff, sigz, sigz_sub, beta, bias=0, mult = 100)
            else:
                value_confs_pred[int(sub) - 1, conf_type] = Environment.get_pred_conf_other(value_trials, cutoff, sigz, sigz_sub, conf_type, bias=0, mult = 100)

    names = ['Experimental', 'Decision', 'Observation', 'Posterior', 'Expected Value']
    latex_string(value_confs, names[0])
    for i in range(4):
        latex_string(value_confs_pred[:, i], names[i+1])





    """overall_confs = np.zeros(num_sub)
    all_confs = np.zeros((num_sub, num_session, 2))
    for sub in trials_sub:
        overall_confs[int(sub) - 1] = Environment.get_real_conf(trials_sub[sub])
        split_trials = Data.split_session_direction(trials_sub[sub])
        for i in range(num_session):
            right_conf = Environment.get_real_conf(split_trials[f'{i}r'])
            left_conf = Environment.get_real_conf(split_trials[f'{i}l'])
            all_confs[int(sub) - 1, i, :] = right_conf, left_conf

    with open(f'results/ high_conf_rates.txt', 'w') as f:
        f.write('overall rates of high confidence for subjects \n')
        f.write(str(overall_confs))
        f.write('\n rates of high confidence for all subjects and sessions/directions \n')
        f.write(str(all_confs))

    print('overall rates of high confidence for subjects')
    print(overall_confs)
    print('rates of high confidence for all subjects and sessions/directions')
    print(all_confs)

    with open('results/other/ overall_high_conf.pkl', 'wb') as handle:
        pickle.dump(overall_confs, handle, protocol=pickle.HIGHEST_PROTOCOL)"""